import os
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from sklearn.decomposition import PCA
from sklearn.manifold import MDS, TSNE
import umap

# Input files that must all be present before a directory is processed.
# The order is the order in which they are checked.
_REQUIRED_FILE_NAMES = (
    "sample_kmer.csv",
    "ld_data.csv",
    "sample_labels.csv",
    "consensus.csv",
    "pairwise_distances_init.csv",  # Assumed file name
)


def _read_motif_sequences(consensus_path):
    """Return the motif sequences stored one per line in *consensus_path*.

    Surrounding whitespace is stripped from every line and trailing blank
    lines are dropped, so a file that ends with extra newlines does not
    create empty motif entries.
    """
    with open(consensus_path, 'r') as file:
        motif_sequences = [line.strip() for line in file]
    while motif_sequences and not motif_sequences[-1]:
        motif_sequences.pop()
    return motif_sequences


def _label_samples(data_ge, motif_sequences, dir_name):
    """Add an ordered categorical ``label_text`` column to *data_ge*.

    Motif ``i`` (1-based position in *motif_sequences*) is attached to the
    samples whose ``label_numeric`` equals ``i``; samples carrying the largest
    ``label_numeric`` are labelled "Random" instead.

    Returns:
    - (data_ge, label_levels): the merged frame and the ordered list of
      category names (motifs in file order, then "Random").
    """
    motif_df = pd.DataFrame({
        'label_numeric': list(range(1, len(motif_sequences) + 1)),
        'motif_sequence': motif_sequences
    })
    data_ge = data_ge.merge(motif_df, on="label_numeric", how="left")

    # Vectorised replacement for a row-wise apply: the highest label is "Random".
    max_label = data_ge['label_numeric'].max()
    label_text = data_ge['motif_sequence'].mask(
        data_ge['label_numeric'] == max_label, "Random"
    )

    unmatched = int(label_text.isna().sum())
    if unmatched:
        warnings.warn(
            f"{unmatched} sample(s) in {dir_name} have a label without a matching "
            f"line in consensus.csv and cannot be assigned a motif label."
        )

    # pd.Categorical rejects duplicate categories, so identical consensus
    # sequences are collapsed into one level (first occurrence keeps its place).
    label_levels = list(dict.fromkeys(motif_sequences + ["Random"]))
    data_ge['label_text'] = pd.Categorical(
        label_text,
        categories=label_levels,
        ordered=True
    )
    return data_ge, label_levels


def _build_color_mapping(label_levels):
    """Map every entry of *label_levels* to a colour of Seaborn's "Set2" palette.

    "Random" always receives the last of the 8 colours; the motif labels cycle
    through the first 7, with a warning when colours have to be reused.
    """
    n_colors = 8  # Including "Random"
    palette = sns.color_palette("Set2", n_colors)
    available_colors = palette[:-1]  # First 7 colors

    motif_labels = [label for label in label_levels if label != "Random"]
    n_clusters = len(motif_labels)
    if n_clusters > len(available_colors):
        warnings.warn(
            f"Number of motif categories ({n_clusters}) exceeds available colors ({len(available_colors)}). Colors will be reused."
        )

    color_mapping = {"Random": palette[-1]}
    for position, label in enumerate(motif_labels):
        color_mapping[label] = available_colors[position % len(available_colors)]
    return color_mapping


def _save_scatter(frame, title, image_path, color_mapping, alpha):
    """Draw ``x_coord``/``y_coord`` of *frame* coloured by ``label_text`` and save it.

    The figure is closed even when drawing or saving fails, so a failed plot
    does not leave an open figure behind.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.scatterplot(
            data=frame,
            x="x_coord",
            y="y_coord",
            hue="label_text",
            palette=color_mapping,
            alpha=alpha,
            edgecolor=None
        )

        # Strip axes, ticks and frame: only the point cloud and legend remain.
        plt.title(title, fontsize=14)
        plt.xlabel("")
        plt.ylabel("")
        sns.despine(left=True, bottom=True)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        plt.box(False)

        plt.legend(
            title="Motif Sequence",
            title_fontsize=10,
            fontsize=8,
            loc='best',
            frameon=False,
            markerscale=1.5
        )

        plt.tight_layout()
        plt.savefig(image_path, dpi=600, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved plot to {image_path}")


def process_directory(dir_path, output_dir):
    """
    Process a specified directory, read data, perform dimensionality reduction,
    generate plots, and save them.

    The directory must contain sample_kmer.csv, ld_data.csv, sample_labels.csv,
    consensus.csv and pairwise_distances_init.csv. One PDF is written for the
    coordinates stored in ld_data.csv ("<dir>_KMAP.pdf") and one for each of
    PCA, MDS, UMAP and t-SNE applied to the pairwise distance matrix
    ("<dir>_<method>.pdf").

    Problems are reported with warnings.warn rather than raised: a problem
    while loading or labelling the data skips the directory, while a failure
    of a single reduction method only skips that method's plot.

    Parameters:
    - dir_path: Path to the directory.
    - output_dir: Path to the output directory (same as the input directory).

    Returns:
    - None.
    """
    # Get the directory name for naming images
    dir_name = os.path.basename(dir_path.rstrip(os.sep))
    print(f"Processing directory: {dir_name}")

    # Check if data files exist
    paths = {name: os.path.join(dir_path, name) for name in _REQUIRED_FILE_NAMES}
    for file_path in paths.values():
        if not os.path.isfile(file_path):
            warnings.warn(f"Missing file: {file_path}. Skipping this directory.")
            return

    # ------------------------------
    # Read data
    # ------------------------------
    try:
        data_strings = pd.read_csv(paths["sample_kmer.csv"], header=None, names=["string"])

        # ld_data.csv stores one coordinate per row; transpose to one sample per row.
        data_coordinates = pd.read_csv(paths["ld_data.csv"], header=None).transpose()
        if data_coordinates.shape[1] < 2:
            warnings.warn(f"ld_data.csv in {dir_name} does not have enough columns. Skipping.")
            return
        data_coordinates.columns = ["x_coord", "y_coord"]

        data_labels = pd.read_csv(paths["sample_labels.csv"], header=None, names=["label"])
        motif_sequences = _read_motif_sequences(paths["consensus.csv"])
        distance_matrix = pd.read_csv(paths["pairwise_distances_init.csv"], header=None).values
    except Exception as e:
        warnings.warn(f"Error reading files in {dir_name}: {e}. Skipping.")
        return

    # ------------------------------
    # Merge data
    # ------------------------------
    try:
        data_ge = pd.concat([data_strings, data_coordinates, data_labels], axis=1)
        # Shift labels by one so they line up with 1-based motif numbers.
        data_ge['label_numeric'] = data_ge['label'].astype(int) + 1
    except Exception as e:
        warnings.warn(f"Error merging data in {dir_name}: {e}. Skipping.")
        return

    # ------------------------------
    # Process Motif Sequences
    # ------------------------------
    try:
        data_ge, label_levels = _label_samples(data_ge, motif_sequences, dir_name)
    except Exception as e:
        warnings.warn(f"Error processing motifs in {dir_name}: {e}. Skipping.")
        return

    # ------------------------------
    # Define Color Palette
    # ------------------------------
    try:
        color_mapping = _build_color_mapping(label_levels)
    except Exception as e:
        warnings.warn(f"Error defining color palette in {dir_name}: {e}. Skipping.")
        return

    # ------------------------------
    # Plot Original ld_data.csv
    # ------------------------------
    try:
        _save_scatter(
            data_ge,
            f"KMAP LD Plot - {dir_name}",
            os.path.join(output_dir, f"{dir_name}_KMAP.pdf"),
            color_mapping,
            alpha=0.6,
        )
    except Exception as e:
        warnings.warn(f"Error plotting original ld_data.csv in {dir_name}: {e}. Skipping.")
        return

    # ------------------------------
    # Dimensionality Reduction and Plotting
    # ------------------------------
    try:
        if distance_matrix.shape[0] != distance_matrix.shape[1]:
            warnings.warn(f"Distance matrix in {dir_name} is not square. Skipping.")
            return

        # Labels are attached to the embedding by row position, so the matrix
        # must describe exactly the samples that were labelled above.
        if distance_matrix.shape[0] != len(data_ge):
            warnings.warn(
                f"Distance matrix in {dir_name} has {distance_matrix.shape[0]} rows "
                f"but {len(data_ge)} samples were loaded. Skipping."
            )
            return

        if not (distance_matrix.diagonal() == 0).all():
            warnings.warn(f"Distance matrix in {dir_name} does not have zeros on the diagonal.")
    except Exception as e:
        warnings.warn(f"Error during dimensionality reduction or plotting in {dir_name}: {e}. Skipping.")
        return

    # Estimators are built lazily so that a failing constructor only affects
    # its own method. t-SNE needs init='random' because PCA initialisation is
    # not available for a precomputed distance matrix.
    reducers = (
        ("PCA", lambda: PCA(n_components=2)),
        ("MDS", lambda: MDS(n_components=2, dissimilarity='precomputed', random_state=42)),
        ("UMAP", lambda: umap.UMAP(n_components=2, metric='precomputed', random_state=42)),
        ("tSNE", lambda: TSNE(n_components=2, metric='precomputed', init='random', random_state=42)),
    )

    for method_name, make_reducer in reducers:
        print(f"Applying {method_name} on {dir_name}")
        try:
            ld_data = make_reducer().fit_transform(distance_matrix)
            ld_df = pd.DataFrame(ld_data, columns=["x_coord", "y_coord"])
            ld_df['label_text'] = data_ge['label_text']

            _save_scatter(
                ld_df,
                f"{method_name} Plot - {dir_name}",
                os.path.join(output_dir, f"{dir_name}_{method_name}.pdf"),
                color_mapping,
                alpha=0.4,
            )
        except Exception as e:
            # Keep going: the remaining methods are independent of this one.
            warnings.warn(
                f"Error during {method_name} dimensionality reduction or plotting in {dir_name}: {e}. Skipping this method."
            )

def main():
    """
    Command-line entry point: parse ``--out_dir`` and process that directory.

    The single directory given on the command line is used both as the source
    of the data files and as the destination for the generated images.

    Raises:
    - ValueError: if the given path does not exist or is not a directory.
    """
    cli = argparse.ArgumentParser(
        description="Visualize data with dimensionality reduction techniques."
    )
    cli.add_argument(
        "--out_dir",
        type=str,
        required=True,
        help="Path to the output directory containing data files.",
    )
    options = cli.parse_args()

    # Images are written next to the data they were generated from.
    input_dir = output_dir = options.out_dir

    if os.path.isdir(input_dir):
        os.makedirs(output_dir, exist_ok=True)
        process_directory(input_dir, output_dir)
        print("Processing complete!")
    else:
        raise ValueError(f"Input directory '{input_dir}' does not exist or is not a directory.")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
